import torch
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, pipeline, StoppingCriteria, StoppingCriteriaList
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
from datasets import load_dataset
from tqdm import tqdm
import numpy as np
import re
import json
import os
import math

MODEL_PATH = "xxx"
TOKENIZER_PATH = MODEL_PATH
DATA_PATH = "xxx.jsonl"
OUTPUT_DIR = "xxx"
os.makedirs(OUTPUT_DIR, exist_ok=True)
BERT_MODEL_NAME = "bert-base-uncased"

# Device layout: with two or more GPUs the policy's primary device is cuda:0
# and the BERT similarity scorer gets cuda:1; otherwise both share a single
# device (the default CUDA device, or the CPU when CUDA is unavailable).
_gpu_count = torch.cuda.device_count()
if _gpu_count > 1:
    print(f"Using {_gpu_count} GPUs")
    MAIN_DEVICE, BERT_DEVICE = "cuda:0", "cuda:1"
else:
    MAIN_DEVICE = BERT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"BERT running on: {BERT_DEVICE}")
print(f"Main model running on: {MAIN_DEVICE}")

# Frozen BERT encoder used only to embed sentences for the repetition penalty.
bert_tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)
bert_model = AutoModel.from_pretrained(BERT_MODEL_NAME).to(BERT_DEVICE)
bert_model.eval()

@torch.no_grad()
def get_bert_embedding(texts):
    """Embed *texts* with the module-level BERT model.

    Inputs are tokenized together with padding and truncated to 512 tokens,
    then run through ``bert_model`` on ``BERT_DEVICE`` without gradients.

    Args:
        texts: a string or a list of strings.

    Returns:
        A CPU tensor holding the final-layer hidden state at position 0
        (the [CLS] slot) for each input text.
    """
    batch = bert_tokenizer(
        texts,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    )
    batch = batch.to(BERT_DEVICE)
    hidden_states = bert_model(**batch).last_hidden_state
    first_token = hidden_states[:, 0, :]
    return first_token.cpu()

def split_sentences(text):
    """Split *text* into sentences on newlines and full-width/ASCII delimiters.

    Delimiters are: newline, 。 ！ ？ ! ? ； ;. Each piece is
    whitespace-stripped and empty pieces are discarded.

    >>> split_sentences("Hi! How are you?")
    ['Hi', 'How are you']
    """
    pieces = (piece.strip() for piece in re.split(r'[\n。！？!?；;]', text))
    return [piece for piece in pieces if piece]

# Policy tokenizer: left padding, 8192-token max length, and the EOS token
# reused as the pad token.
# NOTE(review): presumably the checkpoint ships without a pad token — confirm.
tokenizer = AutoTokenizer.from_pretrained(
    TOKENIZER_PATH, 
    use_fast=True,
    padding_side="left",
    model_max_length=8192,
    legacy=False
)
tokenizer.pad_token = tokenizer.eos_token

# Policy: causal LM with a value head, loaded in bfloat16 and placed by
# device_map="auto". The max_position_embeddings / rope_scaling kwargs are
# presumably forwarded to the underlying model config (the save helper later
# writes the same values back) — confirm for the pinned TRL version.
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    max_position_embeddings=8192,
    rope_scaling=None
)
model.gradient_checkpointing_enable()

# Full fine-tuning: mark every parameter of the wrapped model as trainable.
for param in model.parameters():
    param.requires_grad = True

# The reference model goes on a third GPU when one exists.
if torch.cuda.device_count() > 2:
    ref_device = "cuda:2"
else:
    ref_device = MAIN_DEVICE

# Must run after the policy is fully configured above: the reference is
# derived from `model` as it stands at this point.
# NOTE(review): device_map="auto" may already have spread the policy over
# cuda:1 / cuda:2, so the BERT scorer and this reference copy can share GPUs
# with it — verify the memory layout.
ref_model = create_reference_model(model)
ref_model = ref_model.to(ref_device)

def load_jsonl(path):
    """Read a JSON Lines file into a list of decoded records.

    Blank (whitespace-only) lines are ignored silently. Lines that are not
    valid JSON are reported with their 1-based line number and skipped, so a
    single corrupt record does not abort loading.

    Args:
        path: path to a UTF-8 encoded ``.jsonl`` file. A leading UTF-8 BOM
            is tolerated.

    Returns:
        The decoded JSON values, in file order.
    """
    data = []
    # utf-8-sig strips an optional BOM; otherwise it decodes exactly like utf-8.
    with open(path, "r", encoding="utf-8-sig") as f:
        for line_no, line in enumerate(f, start=1):
            stripped = line.strip()
            if not stripped:
                # Blank lines (e.g. a trailing newline) are not records.
                continue
            try:
                data.append(json.loads(stripped))
            except json.JSONDecodeError:
                print(f"Skipping invalid JSON at line {line_no}: {stripped}")
    return data

dataset = load_jsonl(DATA_PATH)

# PPO hyper-parameters, grouped by concern (keyword order does not matter).
ppo_config = PPOConfig(
    # Run bookkeeping
    model_name=MODEL_PATH,
    seed=42,
    log_with=None,
    optimize_cuda_cache=True,
    # Optimisation schedule.
    # NOTE(review): TRL's PPOTrainer.step() expects exactly batch_size
    # samples per call; the training loop samples RESPONSES_PER_SAMPLE
    # responses per prompt, so keep the two in agreement.
    learning_rate=1e-5,
    batch_size=4,
    mini_batch_size=2,
    gradient_accumulation_steps=1,
    ppo_epochs=4,
    # KL penalty settings
    init_kl_coef=0.01,
    adap_kl_ctrl=False,
    target_kl=None,
    early_stopping=False,
    # Score / value-function shaping
    use_score_scaling=True,
    use_score_norm=True,
    cliprange_value=0.2,
    vf_coef=1.0,
)

ppo_trainer = PPOTrainer(
    config=ppo_config,
    model=model,
    ref_model=ref_model,
    tokenizer=tokenizer,
    dataset=None,  # rollouts are generated manually in the training loop
)

def extract_actions(text):
    """Return the set of reasoning-action names found in *text*.

    Two passes are made:

    1. Marker pass: text enclosed by 【...】, [...], *...*, (...) or introduced
       by "Step N" is cleaned of punctuation and searched for an action
       keyword as a substring (so "[Alternatives]" credits "Alternative").
       At most one keyword is credited per marker: the first that matches in
       the fixed keyword order below.
    2. Word pass: every keyword appearing anywhere in *text* as a whole word
       (case-insensitive) is credited.

    Returns:
        A set of capitalised keyword names, e.g. ``{"Reflection", "Answer"}``.
    """
    # An ordered tuple, not a set: marker matching stops at the first hit, and
    # iteration order over a set of strings changes between interpreter runs
    # (hash randomisation), which made the credited action — and therefore the
    # reward — non-reproducible.
    action_keywords = (
        "decomposition", "reflection", "verification", "however",
        "retry", "transition", "alternative", "answer",
    )
    marker_pattern = r'(?:【|\[|\*|\(|Step\s*\d+[:\-]?)\s*([^\n】\]\*\)]+?)\s*(?:】|\]|\*|\)|:)'
    actions = set()

    for match in re.findall(marker_pattern, text, re.IGNORECASE):
        clean_match = re.sub(r'[^\w\s]', '', match).strip().lower()
        for kw in action_keywords:
            if kw in clean_match:
                actions.add(kw.capitalize())
                break

    for kw in action_keywords:
        if re.search(rf"\b{kw}\b", text, re.IGNORECASE):
            actions.add(kw.capitalize())

    return actions

def calc_repetition_enhanced(text):
    """Return the fraction of sentence pairs in *text* flagged as repeats.

    *text* is split with ``split_sentences``; with fewer than three sentences
    the ratio is 0.0. Every unordered pair of sentences is scored as

        0.4 * (Jaccard overlap of lower-cased whitespace tokens)
        + 0.6 * (cosine similarity of their BERT embeddings)

    and a pair counts as repeated when that score exceeds 0.75.

    Returns:
        ``repeated_pairs / total_pairs``, a float in [0.0, 1.0].
    """
    from itertools import combinations

    sentences = split_sentences(text)
    if len(sentences) < 3:
        return 0.0

    # Build each sentence's token set once instead of once per pair.
    word_sets = [set(sentence.lower().split()) for sentence in sentences]

    # All pairwise cosine similarities from one matrix product of the
    # L2-normalised embeddings, instead of one cosine call per pair.
    embeddings = get_bert_embedding(sentences)
    normed = torch.nn.functional.normalize(embeddings, dim=-1)
    semantic = (normed @ normed.T).tolist()

    repeated_pairs = 0
    total_pairs = 0
    for i, j in combinations(range(len(sentences)), 2):
        shared = word_sets[i] & word_sets[j]
        union = word_sets[i] | word_sets[j]
        lexical_sim = len(shared) / max(len(union), 1)
        similarity = 0.4 * lexical_sim + 0.6 * semantic[i][j]
        if similarity > 0.75:
            repeated_pairs += 1
        total_pairs += 1

    # total_pairs >= 3 here because at least three sentences remain.
    return repeated_pairs / total_pairs

def _extract_pred_label(response, task_type):
    """Return the integer label *response* commits to, or None if none is found.

    Patterns are tried in priority order; the first pattern that matches
    anywhere wins, and its leftmost occurrence is used. ``^`` / ``$`` anchor
    to the start / end of the whole response (no re.MULTILINE).
    """
    label_patterns = [
        r'^Label:\s*(\d+)',
        r'^Sarcasm Label:\s*(\d+)',
        r'^\*\*(\d+)\*\*',
        r'Therefore,\s+the\s+sarcasm\s+label\s+is:\s*\*\*(\d+)\*\*',
        r'Therefore,\s+the\s+sarcasm\s+label\s+is:\s*(\d+)',
        r'the\s+sarcasm\s+label\s+is:\s*\*\*(\d+)\*\*',
        r'the\s+sarcasm\s+label\s+is:\s*(\d+)',
        # task_type comes from the data file: escape it so characters such as
        # "(" or "+" match literally instead of corrupting the pattern (an
        # unescaped "(" added a capture group and crashed int() on a tuple).
        rf'{re.escape(str(task_type))} label\s*[:=]\s*(\d+)',
        r'Final Label\s*[:=]\s*(\d+)',
        r'label\s*[:=]\s*(\d+)',
        r'Answer\s*[:=]\s*(\d+)',
        r'Sarcasm Label:\s*(\d+)$',
        r'Label:\s*(\d+)$',
        r'\*\*(\d+)\*\*$',
        r'\*\*Label:\s*(\d+)\*\*',
        r'\*\*Sarcasm Label:\s*(\d+)\*\*',
    ]
    for pattern in label_patterns:
        match = re.search(pattern, response, re.IGNORECASE)
        if match:
            return int(match.group(1))

    # Last resort: a bolded number at the very end, optionally followed by "."
    match = re.search(r'\*\*(\d+)\*\*(?:\s*\.?\s*)$', response)
    if match:
        return int(match.group(1))
    return None


def _length_reward(word_count, min_length, max_length, base_length, length_weight):
    """Score *word_count* against the target word-count window.

    Below ``min_length`` the reward rises quadratically up to
    ``length_weight``; above ``max_length`` it decays exponentially with the
    relative excess; inside the window it is a Gaussian bump centred on
    ``base_length`` whose width is 0.4 of the window size.
    """
    if word_count < min_length:
        return length_weight * (word_count / min_length) ** 2
    if word_count > max_length:
        if max_length <= 0:
            # No allowance at all: any word is pure excess.
            return 0.0
        excess_ratio = (word_count - max_length) / max_length
        return length_weight * math.exp(-4 * excess_ratio)
    range_half = 0.4 * (max_length - min_length)
    if range_half <= 0:
        # Degenerate window (min_length == max_length == word_count): the
        # length is exactly on target; previously this divided by zero.
        return length_weight
    z = (word_count - base_length) / range_half
    return length_weight * math.exp(-0.5 * z ** 2)


def _structure_reward(response, extracted_actions):
    """Return ``(structure_reward, action_score, connector_score)``.

    ``action_score`` is the number of recognised action names in
    *extracted_actions* divided by 4 (capped at 1.0); ``connector_score`` is
    the number of logical connectors in *response* divided by 5 (capped at
    1.0). The structure reward is at most 0.20.
    """
    valid_actions = {
        "Decomposition", "Reflection", "Verification",
        "However", "Retry", "Transition", "Alternative", "Answer"
    }
    action_score = min(
        sum(1 for a in extracted_actions if a in valid_actions) / 4.0,
        1.0
    )
    connector_pattern = r'\b(because|therefore|thus|hence|however|but|although|then|so|since|consequently|furthermore|moreover|nevertheless|thus)\b'
    connectors = re.findall(connector_pattern, response, re.IGNORECASE)
    connector_score = min(len(connectors) / 5.0, 1.0)
    structure_reward = 0.20 * (0.6 * action_score + 0.4 * connector_score)
    return structure_reward, action_score, connector_score


def compute_reward(response, sample):
    """Score one generated *response* for *sample* and explain the score.

    The reward is ``5 * (accuracy + length + structure + repetition)``:

    * accuracy: +w if the extracted label equals ``sample["label"]``, -w if
      it differs, -0.10 if no label can be extracted (w = 0.70 for
      "Short" samples, 0.60 otherwise);
    * length: shaped around the sample's min/base/max word-count window
      (weight 0.25 for "Short", 0.15 otherwise);
    * structure: 0 for "Short" samples, otherwise up to 0.20 for detected
      reasoning actions and logical connectors;
    * repetition: up to -0.05 for repeated sentence pairs.

    Args:
        response: generated text. Everything from "0-1 as defined in" onward
            is removed before scoring.
        sample: dataset record. Reads ``label`` (required) and optionally
            ``task_type``, ``length_type``, ``min_length``, ``max_length``,
            ``base_length``.

    Returns:
        ``(total_reward, debug_info)``. ``debug_info`` holds every reward
        component plus the parsed label, word count and detected actions.
    """
    response = re.sub(r"\(0-1 as defined in.*$", "", response, flags=re.DOTALL).strip()
    response = re.sub(r"0-1 as defined in.*$", "", response, flags=re.DOTALL).strip()
    ref_label = sample["label"]
    task_type = sample.get("task_type", "sarcasm")
    length_type = sample.get("length_type", "Long")
    is_short = length_type == "Short"

    pred_label = _extract_pred_label(response, task_type)

    # NOTE(review): pred_label is an int; if sample["label"] is stored as a
    # string in the JSONL, every prediction is scored as wrong — confirm the
    # data schema.
    accuracy_weight = 0.70 if is_short else 0.60
    if pred_label is None:
        accuracy_reward = -0.10
    elif pred_label == ref_label:
        accuracy_reward = accuracy_weight
    else:
        accuracy_reward = -accuracy_weight

    word_count = len(response.split())
    length_reward = _length_reward(
        word_count,
        min_length=sample.get("min_length", 80),
        max_length=sample.get("max_length", 300),
        base_length=sample.get("base_length", 180),
        length_weight=0.25 if is_short else 0.15,
    )

    if is_short:
        # "Short" samples get no reward for visible reasoning structure.
        structure_reward = 0.0
        action_score = 0.0
        connector_score = 0.0
        extracted_actions = []
    else:
        extracted_actions = extract_actions(response)
        structure_reward, action_score, connector_score = _structure_reward(
            response, extracted_actions
        )

    repeat_ratio = calc_repetition_enhanced(response)
    repeat_penalty = -0.05 * min(repeat_ratio, 1.0)

    total_reward = (
        accuracy_reward +
        length_reward +
        structure_reward +
        repeat_penalty
    ) * 5.0

    debug_info = {
        "accuracy_reward": accuracy_reward,
        "length_reward": length_reward,
        "structure_reward": structure_reward,
        "repeat_penalty": repeat_penalty,
        "total_reward": total_reward,
        "pred_label": pred_label,
        "ref_label": ref_label,
        "word_count": word_count,
        "actions_detected": list(extracted_actions),
        "action_score": action_score,
        "connector_score": connector_score,
        "repeat_ratio": repeat_ratio,
        "length_type": length_type
    }
    return total_reward, debug_info

# Label definitions shown to the policy inside every prompt. The two
# definitions are separated by "; " (they were previously glued together as
# "etc.0=not sarcasm").
sarcasm_definition = "(1=sarcasm: contains features like surface praise with underlying criticism, contextual incongruity, exaggerated contrast, etc.; 0=not sarcasm)"
    
def save_model_with_config(model, tokenizer, output_dir, *, max_position_embeddings=8192):
    """Save *model* and *tokenizer* to *output_dir*, then patch config.json.

    Weights are written with ``safe_serialization=True`` in shards of at most
    2GB. If ``output_dir/config.json`` exists afterwards, its
    ``rope_scaling`` entry is removed and ``max_position_embeddings`` is set.
    This matches the ``rope_scaling=None`` / ``max_position_embeddings=8192``
    overrides this script passes when loading the policy.

    Args:
        model: object exposing ``save_pretrained`` (e.g. the PPO policy).
        tokenizer: object exposing ``save_pretrained``.
        output_dir: destination directory.
        max_position_embeddings: value written into config.json
            (keyword-only; default 8192).
    """
    model.save_pretrained(
        output_dir,
        safe_serialization=True,
        max_shard_size="2GB"
    )
    tokenizer.save_pretrained(output_dir)

    config_path = os.path.join(output_dir, "config.json")
    if not os.path.exists(config_path):
        return
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    config.pop("rope_scaling", None)
    config["max_position_embeddings"] = max_position_embeddings
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

reward_trends = {
    'epoch': [],
    'avg_reward': [],
    'max_reward': [],
    'min_reward': [],
    'accuracy': [],
    'avg_length': []
}

EPOCHS = 4
RESPONSES_PER_SAMPLE = 4

# The responses sampled for one prompt are handed to ppo_trainer.step() in
# chunks of exactly ppo_config.batch_size (query, response, score) triples,
# so RESPONSES_PER_SAMPLE must be a positive multiple of it.
# NOTE(review): TRL's PPOTrainer.step() rejects calls whose sample count
# differs from config.batch_size — the old one-at-a-time loop could never
# succeed with batch_size=4.
PPO_STEP_SIZE = ppo_config.batch_size
if RESPONSES_PER_SAMPLE <= 0 or RESPONSES_PER_SAMPLE % PPO_STEP_SIZE != 0:
    raise ValueError(
        f"RESPONSES_PER_SAMPLE ({RESPONSES_PER_SAMPLE}) must be a positive "
        f"multiple of ppo_config.batch_size ({PPO_STEP_SIZE})"
    )

# Sampling inputs go to a fourth GPU when one exists, otherwise MAIN_DEVICE.
GENERATION_DEVICE = "cuda:3" if torch.cuda.device_count() > 3 else MAIN_DEVICE


def _build_prompt(text_content):
    """Return the sarcasm-classification prompt wrapped around *text_content*."""
    return (
        f"### Sarcasm Classification Task ###\n"
        f"Instruction: Analyze the sarcasm of the text step by step.\n"
        f"Then output ONLY 'Label: [0-1]' at the very end.\n\n"
        f"Sarcasm label definitions (reference only, DO NOT output):\n"
        f"{sarcasm_definition}\n\n"
        f"Text content: {text_content}\n\n"
        f"Final label and reasoning steps:\n"
    )


for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")
    progress_bar = tqdm(dataset, desc=f"Epoch {epoch+1}")

    epoch_rewards = []
    epoch_accuracies = []
    epoch_lengths = []

    for idx, sample in enumerate(progress_bar):
        prompt = _build_prompt(sample["prompt"])

        base_length = sample.get("base_length", 180)
        max_length = sample.get("max_length", 300)
        token_per_word = 1.7
        # Token budget: ~1.7 tokens per word plus 20% headroom, capped at 2048.
        max_tokens = min(2048, int(max_length * token_per_word * 1.2))

        # The prompt is identical for every sampled response: tokenize once
        # (previously it was tokenized twice per response).
        encoded = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=8192,
            add_special_tokens=True,
        )
        query_tensor = encoded.input_ids.to(GENERATION_DEVICE)
        attention_mask = encoded.attention_mask.to(GENERATION_DEVICE)
        query_len = query_tensor.shape[1]
        query_cpu = query_tensor.squeeze(0).cpu()

        all_query_tensors = []
        all_response_tensors = []
        all_rewards = []
        all_responses = []
        reward_debugs = []

        for _ in range(RESPONSES_PER_SAMPLE):
            with torch.no_grad():
                gen_ids = model.generate(
                    input_ids=query_tensor,
                    attention_mask=attention_mask,
                    max_new_tokens=max_tokens,
                    do_sample=True,
                    top_p=0.9,
                    temperature=0.7,
                    pad_token_id=tokenizer.eos_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    repetition_penalty=1.2,
                )

            # The tokens after the prompt are both what PPO optimises and what
            # gets scored. Decoding exactly those tokens replaces the old
            # "decoded output starts with the prompt" check, which fell back to
            # scoring the prompt text as part of the response.
            response_tensor = gen_ids[0][query_len:].cpu()
            response = tokenizer.decode(response_tensor, skip_special_tokens=True).strip()
            response = re.sub(r"\(0-1 as defined in.*$", "", response, flags=re.DOTALL).strip()
            response = re.sub(r"0-1 as defined in.*$", "", response, flags=re.DOTALL).strip()

            reward, debug_info = compute_reward(response, sample)

            all_query_tensors.append(query_cpu)
            # Previously never appended, so the PPO update loop below never ran.
            all_response_tensors.append(response_tensor)
            all_rewards.append(torch.tensor(reward, dtype=torch.float32))
            all_responses.append(response)
            reward_debugs.append(debug_info)

            del gen_ids
            torch.cuda.empty_cache()

        del query_tensor, attention_mask

        # PPO update on this prompt's responses, PPO_STEP_SIZE triples per call.
        for start in range(0, len(all_response_tensors), PPO_STEP_SIZE):
            stop = start + PPO_STEP_SIZE
            batch_queries = [t.to(MAIN_DEVICE) for t in all_query_tensors[start:stop]]
            batch_responses = [t.to(MAIN_DEVICE) for t in all_response_tensors[start:stop]]
            batch_rewards = [r.to(MAIN_DEVICE) for r in all_rewards[start:stop]]
            try:
                ppo_trainer.step(
                    queries=batch_queries,
                    responses=batch_responses,
                    scores=batch_rewards
                )
            except Exception as e:
                # Best effort: report a failed update and keep training.
                print(f"PPO step failed: {e}")
            finally:
                del batch_queries, batch_responses, batch_rewards
                torch.cuda.empty_cache()

        sample_rewards = [r.item() for r in all_rewards]
        sample_correct = [1 if d['pred_label'] == sample['label'] else 0 for d in reward_debugs]
        epoch_rewards.extend(sample_rewards)
        epoch_accuracies.extend(sample_correct)
        epoch_lengths.extend(d['word_count'] for d in reward_debugs)

        avg_reward = sum(sample_rewards) / len(sample_rewards)
        accuracy = sum(sample_correct) / len(sample_correct)
        progress_bar.set_postfix({
            "avg_reward": f"{avg_reward:.2f}",
            "accuracy": f"{accuracy:.2f}",
            "max_tokens": max_tokens
        })

        if idx % 500 == 0:
            debug_data = {
                "prompt": prompt,
                "responses": all_responses,
                "rewards": [float(r) for r in sample_rewards],
                "debug_info": reward_debugs,
                "sample_metadata": {
                    "min_length": sample.get("min_length"),
                    "max_length": sample.get("max_length"),
                    "base_length": base_length,
                    "length_type": sample.get("length_type"),
                    "task_type": sample.get("task_type"),
                    "max_tokens": max_tokens
                }
            }

            # utf-8 + ensure_ascii=False keeps non-ASCII text readable in the dump.
            with open(f"{OUTPUT_DIR}/rewards_debug_epoch{epoch}_step{idx}.json", "w", encoding="utf-8") as f:
                json.dump(debug_data, f, indent=2, ensure_ascii=False)

            save_model_with_config(model, tokenizer, f"{OUTPUT_DIR}/checkpoint_epoch{epoch}_step{idx}")

    if epoch_rewards:
        reward_trends['epoch'].append(epoch+1)
        reward_trends['avg_reward'].append(np.mean(epoch_rewards))
        reward_trends['max_reward'].append(np.max(epoch_rewards))
        reward_trends['min_reward'].append(np.min(epoch_rewards))
        reward_trends['accuracy'].append(np.mean(epoch_accuracies))
        reward_trends['avg_length'].append(np.mean(epoch_lengths))

        print(f"\nEpoch {epoch+1} Summary:")
        print(f"  Average Reward: {reward_trends['avg_reward'][-1]:.4f}")
        print(f"  Maximum Reward: {reward_trends['max_reward'][-1]:.4f}")
        print(f"  Minimum Reward: {reward_trends['min_reward'][-1]:.4f}")
        print(f"  Accuracy: {reward_trends['accuracy'][-1]:.4f}")
        print(f"  Average Length: {reward_trends['avg_length'][-1]:.1f} words")
    else:
        # Indexing [-1] here would raise (first epoch) or report a previous
        # epoch's numbers under this epoch's heading.
        print(f"\nEpoch {epoch+1}: no samples processed, no summary available.")

    save_model_with_config(model, tokenizer, f"{OUTPUT_DIR}/ppo_epoch{epoch}_final")
    print(f"Epoch {epoch+1} completed. Model saved.")

print("\nTraining Trends Summary:")
print("Epoch | Avg Reward | Max Reward | Min Reward | Accuracy | Avg Length")
print("---------------------------------------------------------------")
for i in range(len(reward_trends['epoch'])):
    print(f"{reward_trends['epoch'][i]:5} | "
          f"{reward_trends['avg_reward'][i]:9.4f} | "
          f"{reward_trends['max_reward'][i]:9.4f} | "
          f"{reward_trends['min_reward'][i]:9.4f} | "
          f"{reward_trends['accuracy'][i]:8.4f} | "
          f"{reward_trends['avg_length'][i]:10.1f}")
